import gym
from gym.wrappers import Monitor
# from option import Option
from env.mujoco_env.reacher_env import ReacherGymEnv
from env.mujoco_env.reacher_env import RMReacherGymEnv
from env.mujoco_env.reacher_env import ReacherGymEnvEval
import os
import torch
import gym
from monitor import Monitor
from option import *
import numpy as np

###############
# Load Option #
###############
# Number of FSA states (nF) for each supported task. run_rollout treats FSA
# state nF - 1 as the goal, so nF must match the selected task; deriving it
# from this table keeps the two from drifting apart when switching tasks.
TASK_NF = {
    'composite': 7,
    'sequential': 5,
    'OR': 3,
}
task_name = 'OR'
nF = TASK_NF[task_name]

# Saved checkpoint for the selected task; requires LOF_PKG_PATH to be set in
# the environment (a missing variable raises KeyError).
option_load_path = os.path.join(
    os.environ['LOF_PKG_PATH'], 'experiments', 'rm', task_name, 'pyt_save', 'model950.pt'
)

option = Option(option_load_path)

#################
# Construct Env #
#################

# Evaluation environment (training=False, rendering not headless, 800-step
# horizon) wrapped in a Monitor that writes into ./video. Note: `Monitor`
# resolves to the local `monitor` module's class, because that import shadows
# gym.wrappers.Monitor. video_callable answers True for every episode id
# (presumably: record every episode — confirm in monitor.Monitor).
env = Monitor(
    RMReacherGymEnv(
        nF=nF,
        task_name=task_name,
        training=False,
        env_config=dict(headless=False, horizon=800),
    ),
    './video',
    video_callable=lambda episode_id: True,
    force=True,
)

###############
# Run Rollout #
###############
def run_rollout(policy, env, num_episodes, max_num_steps=800):
    """Roll out ``policy`` in ``env`` for ``num_episodes`` episodes, printing each return.

    Args:
        policy: object exposing ``get_action(obs)``; it is called with the
            current observation converted to a float32 torch tensor.
        env: environment exposing ``nF``, ``reset``, ``step``, ``render``,
            ``set_task_done`` and ``close``. Observations must be numpy arrays
            (they go through ``torch.from_numpy``); ``obs[0]`` is read as the
            current FSA state and state ``env.nF - 1`` is treated as the goal.
            The env is closed when this function returns or raises.
        num_episodes: number of episodes to run.
        max_num_steps: cap on ``env.step`` calls per episode, so a rollout
            terminates even if the env never reports done. The default of 800
            matches the horizon this script configures for the env.

    Returns:
        List with the summed (undiscounted) reward of each episode.
    """
    goal_state = env.nF - 1
    returns = []

    try:
        for i in range(num_episodes):
            task_done = False
            R = 0
            obs = env.reset()
            f = 0
            num_steps = 0

            while not task_done and num_steps < max_num_steps:
                env.render()

                a = policy.get_action(torch.from_numpy(obs).float())
                obs, reward, task_done, info = env.step(a)
                num_steps += 1
                f = obs[0]
                R += reward

                if f == goal_state:
                    # Flag the episode as finished inside the env. The local
                    # `task_done` is deliberately not set here, so the loop
                    # takes one more step and presumably relies on the env
                    # reporting done on it — confirm against set_task_done.
                    env.set_task_done(True)

            print(f"Episode {i} return: {R} | FSA: {f}")
            returns.append(R)
    finally:
        # Always close so the monitor can finalize its output, even on error.
        env.close()

    return returns

#######
# Run #
#######
num_eval_episodes = 10
run_rollout(option, env, num_episodes=num_eval_episodes)